import numpy as np
from scipy.optimize import fsolve
from util import undersampled

def TS(design_mean, design_var, design_used, n0):
    """Return the index of the next alternative to sample.

    The rule has three stages, checked in order:

    1. Warm-up: if any alternative has fewer than ``n0`` samples, return the
       alternative with the fewest samples (first one on ties).
    2. If ``undersampled`` reports any alternatives, return the one among
       them with the fewest samples.
    3. Otherwise, solve the system defined by ``equations`` with ``fsolve``
       for allocation weights ``omega``. Normalize them to sum to one, and
       return the alternative minimizing ``design_used - total * omega``.
       That is the alternative furthest below its target share of the
       total sample count.

    Parameters
    ----------
    design_mean, design_var : array_like, length K
        Per-alternative means and variances (presumably sample estimates;
        confirm with callers).
    design_used : array_like, length K
        Number of samples taken from each alternative so far.
    n0 : int
        Minimum per-alternative sample count required before stages 2/3.

    Returns
    -------
    Index of the alternative to sample next.
    """
    total_sample = np.sum(design_used)
    # NOTE(review): ``undersampled`` lives in ``util``; its result is assumed
    # to be a sequence of alternative indices (it is indexed and iterated).
    undersample = undersampled(len(design_mean), total_sample, design_used)

    # Stage 1: warm-up until every alternative has at least n0 samples.
    if np.min(design_used) < n0:
        return np.argmin(design_used)

    # Stage 2: among undersampled alternatives, pick the least-sampled one.
    if len(undersample) > 0:
        # Materialize a list first: np.array(<generator>) produces a 0-d
        # object array, which made argmin always return 0 (first candidate).
        extracted_values = np.array([design_used[i] for i in undersample])
        return undersample[np.argmin(extracted_values)]

    # Stage 3: solve for the allocation weights, starting from a uniform
    # allocation and a common rate of 1.0.
    K = len(design_mean)
    initial_guess = np.append(np.ones(K) / K, 1.0)
    solution = fsolve(equations, initial_guess, args=(design_mean, design_var))
    omega = solution[:K]
    # NOTE(review): fsolve does not constrain omega to be non-negative; the
    # normalization below assumes the solver converged to a valid allocation.
    omega = omega / np.sum(omega)
    # Sample the alternative whose count lags its target share the most.
    return np.argmin(design_used - total_sample * omega)

def equations(vars, design_mean, design_var):
    """Residuals of the allocation system solved with ``fsolve`` in ``TS``.

    The unknown vector is ``vars = [omega_0, ..., omega_{K-1}, common_value]``.
    The best alternative is ``argmax(design_mean)`` (the first one on ties).
    All residuals are zero when:

    * for every non-best alternative ``i``,
      ``(mu_i - mu_best)**2 / (2 * (var_best / omega_best + var_i / omega_i))``
      equals ``common_value`` (K - 1 equations);
    * ``omega_best == sqrt(sum_{i != best} (var_best / var_i) * omega_i**2)``
      (1 equation);
    * ``sum(omega) == 1`` (1 equation).

    Parameters
    ----------
    vars : array_like, length K + 1
        Candidate weights followed by the common value.
    design_mean, design_var : array_like, length K
        Per-alternative means and variances. Plain lists are accepted, as
        they are in ``TS``.

    Returns
    -------
    list
        K + 1 residuals, in the order listed above.
    """
    design_mean = np.asarray(design_mean, dtype=float)
    design_var = np.asarray(design_var, dtype=float)
    vars = np.asarray(vars, dtype=float)

    K = design_mean.shape[0]
    best_id = int(np.argmax(design_mean))
    mu_best = design_mean[best_id]
    omega = vars[:K]
    common_value = vars[K]

    # Indices of all alternatives other than the current best.
    others = np.flatnonzero(np.arange(K) != best_id)

    # K - 1 residuals: each pairwise rate must equal the common value.
    rates = (design_mean[others] - mu_best) ** 2 / (
        2 * (design_var[best_id] / omega[best_id] + design_var[others] / omega[others])
    )
    eq = list(rates - common_value)

    # Balance condition tying the best alternative's weight to the others.
    sum_K = np.sum((design_var[best_id] / design_var[others]) * omega[others] ** 2)
    eq.append(np.sqrt(sum_K) - omega[best_id])

    # Weights form a probability vector.
    eq.append(np.sum(omega) - 1)

    return eq

